{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize iterative minima fininding algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook outlines a way of using pythreejs as a way of getting input from the user with a `Picker`, and then triggering kernel-side computations based on this, which in turn will update the visualization.\n",
    "\n",
    "The example case here is an iterative algorithm that minimizes a function of two variables. The visualization will be a grid-based surface plot of this function. By double-clicking the surface, the user will start the algorithm. The results of the alogrithm will be live-plotted as a line on top of the surface."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing dependencies\n",
    "\n",
    "First, we import the dependencies we will need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythreejs import *\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "from ipywidgets import HTML, Output, VBox, jslink\n",
    "\n",
    "view_width = 600\n",
    "view_height = 400"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define our iterative algorithm\n",
    "\n",
    "Here is a simple random descent method for finding local minima, that yields its iterative values. There are many issues with this algorithm, and as such it has a lot of potential for improvement! (left as an exercise for the reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_minima(f, start=(0, 0), xlim=None, ylim=None):\n",
    "    rate = 0.1 # Learning rate\n",
    "    max_iters = 200 # maximum number of iterations\n",
    "    iters = 0 # iteration counter\n",
    "    \n",
    "    cur = np.array(start[:2])\n",
    "    previous_step_size = 1 #\n",
    "    cur_val = f(cur[0], cur[1]) \n",
    "    \n",
    "    while (iters < max_iters and\n",
    "           xlim[0] <= cur[0] <= xlim[1] and ylim[0] <= cur[1] <= ylim[1]):\n",
    "        iters = iters + 1\n",
    "        candidate = cur - rate * (np.random.rand(2) - 0.5)\n",
    "        candidate_val = f(candidate[0], candidate[1])\n",
    "        if candidate_val >= cur_val:\n",
    "            continue   # Bad guess, try again\n",
    "        prev = cur\n",
    "        cur = candidate\n",
    "        cur_val = candidate_val\n",
    "        previous_step_size = np.abs(cur - prev)\n",
    "        yield tuple(cur) + (cur_val,)\n",
    "\n",
    "    print(\"The local minimum occurs at\", cur)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define test function\n",
    "\n",
    "Give a function to test minima finder on. Here we simply use $ f(x, y) = x^2 + y^2 $. You can also try with the surface\n",
    "$ f(x, y) = x^2 - y^2 $ to see the effect of an instable surface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x, y):\n",
    "    return x ** 2 + y ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup figure\n",
    "\n",
    "Interrogate function on a grid in order to visualize. This uses numpy helper code to generate the grid, and evaluate the function. Note the two evalutations: One at the grid lattice points, and one in the center of each grid square!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx, ny = (20, 20)  # grid resolution\n",
    "xmax = 1           # grid extent (+/-)\n",
    "x = np.linspace(-xmax, xmax, nx)\n",
    "y = np.linspace(-xmax, xmax, ny)\n",
    "step = x[1] - x[0]\n",
    "xx, yy = np.meshgrid(x, y)\n",
    "# Grid lattice values:\n",
    "grid_z = np.vectorize(f)(xx, yy)\n",
    "# Grid square center values:\n",
    "center_z = np.vectorize(f)(0.5 * step + xx[:-1,:-1], 0.5 * step + yy[:-1,:-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup code for 3D visualization with user point picking with mouse. Here we use the `SurfaceGeometry` and `SurfaceGrid` helper classes (not direct three.js classes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surface geometry:\n",
    "surf_g = SurfaceGeometry(z=list(grid_z.flat), \n",
    "                         width=2 * xmax,\n",
    "                         height=2 * xmax,\n",
    "                         width_segments=nx - 1,\n",
    "                         height_segments=ny - 1)\n",
    "\n",
    "# Surface material. Note that the map uses the center-evaluated function-values:\n",
    "surf = Mesh(geometry=surf_g,\n",
    "            material=MeshLambertMaterial(map=height_texture(center_z, 'YlGnBu_r')))\n",
    "\n",
    "# Grid-lines for the surface:\n",
    "surfgrid = SurfaceGrid(geometry=surf_g, material=LineBasicMaterial(color='black'),\n",
    "                       position=[0, 0, 1e-2])  # Avoid overlap by lifting grid slightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up scene:\n",
    "key_light = DirectionalLight(color='white', position=[3, 5, 1], intensity=0.4)\n",
    "c = PerspectiveCamera(position=[0, 3, 3], up=[0, 0, 1], aspect=view_width / view_height,\n",
    "                      children=[key_light])\n",
    "\n",
    "scene = Scene(children=[surf, c, surfgrid, AmbientLight(intensity=0.8)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now plot our figure. Note that initially, this will not have the interactive features, but we will add the interactivity below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "renderer = Renderer(camera=c, scene=scene,\n",
    "                    width=view_width, height=view_height,\n",
    "                    controls=[OrbitControls(controlling=c)])\n",
    "\n",
    "out = Output()        # An Output for displaying captured print statements\n",
    "box = VBox([renderer])\n",
    "display(box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding pickers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let us add a simple position tracker. This will find what point on the surface that the mouse is hovering over. We will represent this point with a pink sphere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Picker object\n",
    "hover_picker = Picker(controlling=surf, event='mousemove')\n",
    "renderer.controls = renderer.controls + [hover_picker]\n",
    "\n",
    "# A sphere for representing the current point on the surface\n",
    "hover_point = Mesh(geometry=SphereGeometry(radius=0.05),\n",
    "                   material=MeshLambertMaterial(color='hotpink'))\n",
    "scene.add(hover_point)\n",
    "\n",
    "# Have sphere follow picker point:\n",
    "jslink((hover_point, 'position'), (hover_picker, 'point'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will observe the changes to the hover point, and display its coordinates in a label which we add to the figure above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coord_label = HTML()  # A label for showing hover picker coordinates\n",
    "\n",
    "def on_hover_change(change):\n",
    "    coord_label.value = 'Pink point at (%.3f, %.3f, %.3f)' % tuple(change['new'])\n",
    "\n",
    "on_hover_change({'new': hover_point.position})\n",
    "hover_picker.observe(on_hover_change, names=['point'])\n",
    "box.children = (coord_label,) + box.children"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we set up a picker for when the user double clikcs on the surface. This should trigger the execution and visualization of the alogrithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create our picker for the double-click event (\"dblclick\")\n",
    "click_picker = Picker(controlling=surf, event='dblclick')\n",
    "\n",
    "def on_click(change):\n",
    "    value = change['new']\n",
    "    with out:\n",
    "        print('Clicked on %s' % (value,))\n",
    "\n",
    "    # Add a red sphere on the picked point\n",
    "    point = Mesh(geometry=SphereGeometry(radius=0.05), \n",
    "                 material=MeshLambertMaterial(color='red'),\n",
    "                 position=value)\n",
    "    scene.add(point)\n",
    "    \n",
    "    # Plot solution as a red line, this will start out empty\n",
    "    points = [value]\n",
    "    line = Line2(geometry=LineGeometry(positions=points), material=LineMaterial(color='red', linewidth=2))\n",
    "    scene.add(line)\n",
    "    with out:  # Pick up any print statements in the algorithm\n",
    "        for pt in find_minima(f, value, [-xmax, xmax], [-xmax, xmax]):\n",
    "            # For each point, update the line:\n",
    "            pt = list(pt)\n",
    "            pt[2] += 1e-2   # offset to clear surface\n",
    "            line.geometry = LineGeometry(positions=np.vstack([line.geometry.positions, pt]))\n",
    "\n",
    "\n",
    "# When the point selected by the picker changes, trigger our function:\n",
    "click_picker.observe(on_click, names=['point'])\n",
    "\n",
    "# Update figure:\n",
    "renderer.controls = renderer.controls + [click_picker]\n",
    "box.children = box.children + (out,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Final note: This notebook tries to explain the visualization code for a specific scenario. If you are more interested in understanding how different iterative minimization algorithms work, you should extract this visualization code to an external function that you can import. Then you can use it to:\n",
    "- Visualize the output of your new and improved algorithm while implementing it.\n",
    "- Visually compare different algorithms for various starting points, and various function landscapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
